library(tidyverse)
library(haven)
library(texreg)
library(margins)
library(broom)
library(prediction)
library(dotwhisker)

# Setup ----
# Every table and figure path is built by pasting a file name directly onto
# this prefix, so it must keep its trailing slash.
outputs_path <- "../../outputs/replication/"

# Project helpers (model fitting and margins functions) and the texreg
# coefficient map, sourced in this order into the global environment.
walk(c("functions_exp_models.R", "texreg_variable_list.R"), source)

# Tidied experiment data used by every model below.
ex <- read_rds("../../data/replication/experiment_data_tidied.rds")

# Covariate names for the experiment models (note: the list also includes
# the treatment indicator itself).
controls_vec <- c(
  "treatment_w21", "male_w21", "white_british_w21",
  "age_w21", "class_w21", "education_w21"
)

# Outcome variables, each paired with the short label used for table column
# names and plot facets. Row order fixes the order of the table columns.
dvars <- tibble(
  variables = c("uk_covid_performance_w21_num",
                "covid_overall_gov_fault_w21",
                "death_toll_gov_fault_w21",
                "vaccine_gov_fault_w21",
                "retro_handle_w21_num",
                "retro_handle_change_w21"),
  neat_names = c("Overall",
                 "Overall resp",
                 "Death toll",
                 "Vaccine",
                 "Retro handle",
                 "Change in retro handle")
)

# Name of the specification reported in the regression table and the main
# figure. It is a string so that both the model list (get(name)) and its
# margins (get(paste0(name, "_margins"))) can be looked up from it later.
presented_model_name <- "robust_party"

# Baseline specification, using the `wt_new_W21` weight column. It returns
# one result per outcome, each holding `no_controls` and `with_controls`
# fits. This and the alternative specifications below can all be chosen via
# presented_model_name, and their margins feed the robustness plots.
base_model <- map(.x = dvars$variables,
                  .f = ols_treatment_interaction,
                  data = ex,
                  weights = "wt_new_W21",
                  interaction_term = "party_id_gov_op_w21")

# Partisan-conditional treatment margins for the baseline models, one
# tibble per specification flavour. Only the with-controls version is
# stacked into all_margins (labelled "Base model").
no_controls_margins <- map(base_model, function(fits) fits$no_controls) %>%
  map2(.y = dvars$variables, .f = get_margins_gov_op_intx, data = ex) %>%
  bind_rows() %>%
  mutate(spec = "No controls")

with_controls_margins <- map(base_model, function(fits) fits$with_controls) %>%
  map2(.y = dvars$variables, .f = get_margins_gov_op_intx, data = ex) %>%
  bind_rows() %>%
  mutate(spec = "Base model")

# Alternative specification: the same models re-weighted with
# `wt_with_pol_att_w21`, on the full sample.
pol_att_weighting <- map(.x = dvars$variables,
                         .f = ols_treatment_interaction,
                         data = ex,
                         weights = "wt_with_pol_att_w21",
                         interaction_term = "party_id_gov_op_w21")

pol_att_weighting_margins <- pol_att_weighting %>%
  map(function(fits) fits$with_controls) %>%
  map2(.y = dvars$variables, .f = get_margins_gov_op_intx, data = ex) %>%
  bind_rows() %>%
  mutate(spec = "+ weights for pol. attention")

# Robustness specification: political-attention weights, dropping
# respondents flagged by `rm_for_vaccine_attitude`.
# The filtered sample is built once and passed to both the model fits and
# the margins. This keeps the margins on the estimation sample; previously
# they were evaluated on the full `ex`, which still contains the excluded
# respondents.
# NOTE(review): filter() also drops rows where the flag is NA, not only
# where it equals 1 — confirm that is intended.
robust_vaccine_data <- ex %>%
  filter(rm_for_vaccine_attitude != 1)

robust_vaccine <- map(dvars$variables,
                      data = robust_vaccine_data,
                      weights = "wt_with_pol_att_w21",
                      interaction_term = "party_id_gov_op_w21",
                      ols_treatment_interaction)
robust_vaccine_margins <- map2(
    map(robust_vaccine, function(x) x$with_controls),
    dvars$variables,
    data = robust_vaccine_data,
    get_margins_gov_op_intx) %>%
  bind_rows() %>%
  mutate(spec = "+ remove vaccine hesitant")

# Robustness specification (the one chosen by `presented_model_name`):
# political-attention weights, dropping respondents flagged by either
# `rm_for_party` or `rm_for_vaccine_attitude`.
# The filtered sample is built once and passed to both the model fits and
# the margins. This keeps the margins on the estimation sample; previously
# they were evaluated on the full `ex`, which still contains the excluded
# respondents.
# NOTE(review): filter() also drops rows where a flag is NA, not only where
# it equals 1 — confirm that is intended.
robust_party_data <- ex %>%
  filter(rm_for_party != 1) %>%
  filter(rm_for_vaccine_attitude != 1)

robust_party <- map(dvars$variables,
                    data = robust_party_data,
                    weights = "wt_with_pol_att_w21",
                    interaction_term = "party_id_gov_op_w21",
                    ols_treatment_interaction)
robust_party_margins <- map2(
    map(robust_party, function(x) x$with_controls),
    dvars$variables,
    data = robust_party_data,
    get_margins_gov_op_intx) %>%
  bind_rows() %>%
  mutate(spec = "+ remove by party")

# Robustness specification: political-attention weights, dropping
# respondents flagged by `rm_for_vaccine_attitude` and keeping only
# `partyIdW20` codes 1, 2 and 10 (presumably Con / Lab / None, per the spec
# label — confirm against the codebook).
# The filtered sample is built once and passed to both the model fits and
# the margins. This keeps the margins on the estimation sample; previously
# they were evaluated on the full `ex`, which still contains the excluded
# respondents.
just_con_lab_data <- ex %>%
  filter(rm_for_vaccine_attitude != 1) %>%
  filter(partyIdW20 %in% c(1, 2, 10))

just_con_lab <- map(dvars$variables,
                    data = just_con_lab_data,
                    weights = "wt_with_pol_att_w21",
                    interaction_term = "party_id_gov_op_w21",
                    ols_treatment_interaction)
just_con_lab_margins <- map2(
    map(just_con_lab, function(x) x$with_controls),
    dvars$variables,
    data = just_con_lab_data,
    get_margins_gov_op_intx) %>%
  bind_rows() %>%
  mutate(spec = "+ remove all but Con / Lab / None")

# Stack the with-controls margins from every specification for the
# robustness plots. The partisan group in `term` is moved into `facet`, and
# `term` is overwritten with the specification label so dwplot() shows one
# row per specification.
# All three partisan groups are listed as levels. Leaving "Non-partisan" out
# would silently turn those rows' facet into NA, even though later code
# matches on the "Non-partisan" value explicitly.
all_margins <- bind_rows(
  with_controls_margins,
  pol_att_weighting_margins,
  robust_vaccine_margins,
  robust_party_margins,
  just_con_lab_margins) %>%
  mutate(facet = factor(term, levels = c("Government partisan",
                                         "Opposition partisan",
                                         "Non-partisan"))) %>%
  mutate(term = spec) %>%
  arrange(facet)

# Table styling shared by the experiment tables, wrapped here rather than in
# the helpers file so that this script can choose its own coefficient map
# and scale box.
#
# list_of_models: the models to tabulate, one table column each.
# ...:            further arguments forwarded unchanged to
#                 generic_tidy_texreg() (e.g. caption, label, groups, file).
# Returns the result of generic_tidy_texreg(), visibly.
custom_texreg <- function(list_of_models, ...) {
  table_out <- generic_tidy_texreg(
    list_of_models,
    sideways = FALSE,
    custom.note = "\\normalsize{%stars}",
    custom.coef.map = texreg_variable_list,
    scalebox = 0.6,
    ...
  )
  table_out
}

# Regression table ----
# Resolve the chosen specification by name and keep each outcome's
# with-controls fit (one table column per outcome).
presented_model <- get(presented_model_name)
results_with_controls <- map(presented_model,
                             function(fits) fits$with_controls)

# Opening words of the table caption; the specification detail is appended
# after it.
caption_suffix <- paste("Results from OLS regressions of all dependent",
                        "variables on treatment status",
                        "interacted with party ID")

custom_texreg(
  results_with_controls,
  custom.header = list(
    "\\textit{Dependent variable:}" = seq_along(dvars$neat_names)
  ),
  custom.model.names = dvars$neat_names,
  caption = paste0(caption_suffix, ", with full controls added"),
  label = "tab: gov_op_intx_controls",
  # Row positions in the rendered table — presumably following the order of
  # texreg_variable_list; keep them in sync if that map changes.
  groups = list("Treatment: (ref = Control)" = 1:2,
                "Partisanship: (ref = Opposition)" = 3:4,
                "Age: (ref = Under 18)" = 11:13,
                "Class: (ref = Higher Managerial)" = 14:18,
                "Education: (ref = Undergrad or Higher)" = 19:22),
  file = paste0(outputs_path, "experiment_with_controls.tex")
)

# Margins for the presented specification, with the `term` and `model`
# columns swapped: afterwards `term` holds the treatment arm (y axis in
# dwplot) and `model` the partisan group (colour).
presented_model_margins <- get(paste0(presented_model_name, "_margins")) %>%
  rename(term = model, model = term) %>%
  # Significant partisan differences are flagged by hand, read off the
  # interaction terms in the regression output (e.g. via
  # screenreg(results_with_controls)). They must be revisited whenever the
  # presented specification or the data change.
  mutate(signif_diff = case_when(
    model == "Non-partisan" ~ "No Signif. Difference",
    term == "Positive treatment" &
      dvar %in% c("uk_covid_performance_w21_num",
                  "retro_handle_w21_num",
                  "retro_handle_change_w21") ~ "Signif. Difference",
    term == "Negative treatment" &
      dvar == "covid_overall_gov_fault_w21" ~ "Signif. Difference",
    TRUE ~ "No Signif. Difference"))

# Main-figure data: give outcomes their display titles (in a fixed facet
# order), fix the order of the treatment arms, drop the two outcomes not
# shown in the figure, and shorten the partisan-group labels. Labels without
# a lookup entry become NA.
with_order <- local({
  outcome_titles <- c("Overall" = "Overall performance",
                      "Overall resp" = "Responsibility (overall)",
                      "Death toll" = "Responsibility (death toll)",
                      "Vaccine" = "Responsibility (vaccine rollout)",
                      "Retro handle" = "Retro handle",
                      "Change in retro handle" = "Change in retro handle")
  group_short <- c("Government partisan" = "Government",
                   "Opposition partisan" = "Opposition",
                   "Non-partisan" = "None")

  presented_model_margins %>%
    mutate(
      neat_names = factor(
        unname(outcome_titles[as.character(neat_names)]),
        levels = unname(outcome_titles)),
      term = factor(term, levels = c("Negative treatment",
                                     "Positive treatment"))) %>%
    filter(dvar != "death_toll_gov_fault_w21",
           dvar != "vaccine_gov_fault_w21") %>%
    mutate(model = unname(group_short[as.character(model)]))
})

# Main figure ----
# Conditional treatment effects by partisan group for the presented
# specification, one facet per outcome. Dot shape encodes the hand-coded
# significance flag; its legend is suppressed.
# Colours are mapped by group name rather than by position. Unnamed values
# would be assigned by whatever level order the groups end up with, which
# can hand the wrong party colour to a group. Government uses the same blue
# ("#0087DC") as the Government bands in the robustness plots.
all_plot <- dwplot(with_order,
                   dodge_size = 0.80,
                   vline = geom_vline(xintercept = 0,
                                      colour = "grey60",
                                      linetype = 2),
                   dot_args = list(aes(colour = model,
                                       shape = signif_diff))) +
  theme_bw() +
  theme(legend.position = "bottom") +
  guides(color = guide_legend(reverse = TRUE),
         shape = "none") +
  scale_colour_manual(values = c(Government = "#0087DC",
                                 Opposition = "#E4003B",
                                 None = "grey"),
                      name = "Partisanship") +
  labs(colour = "Partisanship") +
  xlab("Conditional average marginal treatment effect") +
  facet_wrap(~neat_names, nrow = 3)

# Save the same 6 x 5 inch figure as both PDF and PNG. Arguments are named
# in full: `file =` only worked through partial matching of `filename`.
walk(c(".pdf", ".png"), function(ext) {
  ggsave(filename = paste0(outputs_path, "experiment_all_plotted", ext),
         plot = all_plot,
         height = 5,
         width = 6,
         units = "in")
})

# Robustness plot data ----
# Margins from every specification, relabelled for display: outcomes get
# their figure titles (in a fixed facet order) and partisan groups get short
# names. Rows for the death-toll and vaccine outcomes are removed, as are
# non-partisans (too crowded on the plot). Facets without a short name
# become NA and are dropped by the final filter as well.
neat_all_margins <- local({
  outcome_titles <- c("Overall" = "Overall performance",
                      "Overall resp" = "Responsibility (overall)",
                      "Death toll" = "Responsibility (death toll)",
                      "Vaccine" = "Responsibility (vaccine rollout)",
                      "Retro handle" = "Retro handle",
                      "Change in retro handle" = "Change in retro handle")
  group_short <- c("Government partisan" = "Government",
                   "Opposition partisan" = "Opposition",
                   "Non-partisan" = "None")

  all_margins %>%
    mutate(neat_names = factor(
      unname(outcome_titles[as.character(neat_names)]),
      levels = unname(outcome_titles))) %>%
    filter(dvar != "death_toll_gov_fault_w21",
           dvar != "vaccine_gov_fault_w21") %>%
    mutate(facet = unname(group_short[as.character(facet)])) %>%
    filter(facet != "None")
})

# Robustness plots ----
# Each treatment arm gets one plot showing the arm's conditional effect for
# Government and Opposition partisans under every specification, faceted by
# outcome. A shaded band spans each group's smallest to largest estimate.

# Party colours shared by the dots and the bands. They are named so that a
# group's dots and its band always agree. Unnamed values were assigned by
# level order, which coloured Government's dots red under a blue band.
robustness_colours <- c(Government = "#0087DC", Opposition = "#E4003B")

# Keep the rows for one treatment arm and make the partisan group dwplot()'s
# `model` column. `.env` ensures the arm label comes from the argument, even
# if the data has a column with the same name.
select_treatment_arm <- function(margins, treatment) {
  margins %>%
    filter(model == .env$treatment) %>%
    select(-model) %>%
    rename(model = facet)
}

# Add, to every row, the min and max estimate within its partisan group and
# outcome, i.e. the range of estimates across specifications.
add_estimate_range <- function(arm_margins) {
  arm_margins %>%
    group_by(model, dvar) %>%
    mutate(min_rect = min(estimate),
           max_rect = max(estimate)) %>%
    ungroup()
}

# Dot-and-whisker plot of one arm's margins, faceted by outcome.
plot_robustness_arm <- function(arm_margins, x_label) {
  dwplot(arm_margins,
         dodge_size = 0.80,
         vline = geom_vline(xintercept = 0,
                            colour = "grey60",
                            linetype = 2),
         dot_args = list(aes(colour = model))) +
    theme_bw() +
    theme(legend.position = "bottom") +
    guides(color = guide_legend(reverse = TRUE),
           shape = "none") +
    scale_colour_manual(values = robustness_colours,
                        name = "Partisanship") +
    labs(colour = "Partisanship") +
    xlab(x_label) +
    facet_wrap(~neat_names, nrow = 2)
}

# Overlay the range bands, one geom_rect layer per partisan group.
# NOTE(review): geom_rect draws one rectangle per row, and there is one row
# per specification, so the alpha = 0.05 shading stacks — confirm intended.
add_range_bands <- function(plot, rect_data) {
  plot +
    geom_rect(data = rect_data %>% filter(model == "Government"),
              aes(xmin = min_rect, xmax = max_rect, ymin = -Inf, ymax = Inf),
              alpha = 0.05, fill = robustness_colours[["Government"]]) +
    geom_rect(data = rect_data %>% filter(model == "Opposition"),
              aes(xmin = min_rect, xmax = max_rect, ymin = -Inf, ymax = Inf),
              alpha = 0.05, fill = robustness_colours[["Opposition"]])
}

pos <- select_treatment_arm(neat_all_margins, "Positive treatment")
robustness_plot_positive <- plot_robustness_arm(
  pos, "Conditional average marginal treatment effect of +ve treatment")
pos_rect <- add_estimate_range(pos)
final_pos_robust_plot <- add_range_bands(robustness_plot_positive, pos_rect)

neg <- select_treatment_arm(neat_all_margins, "Negative treatment")
robustness_plot_negative <- plot_robustness_arm(
  neg, "Conditional average marginal treatment effect of -ve treatment")
neg_rect <- add_estimate_range(neg)
final_neg_robust_plot <- add_range_bands(robustness_plot_negative, neg_rect)

# Save both robustness figures as 6 x 5 inch PDFs. Arguments are named in
# full: `file =` only reached ggsave()'s `filename` through partial argument
# matching, and the plot was matched by position.
ggsave(filename = paste0(outputs_path, "pos_robust", ".pdf"),
       plot = final_pos_robust_plot,
       height = 5,
       width = 6,
       units = "in")
ggsave(filename = paste0(outputs_path, "neg_robust", ".pdf"),
       plot = final_neg_robust_plot,
       height = 5,
       width = 6,
       units = "in")

